{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RAG Workflow with Reranking\n",
    "\n",
    "This notebook walks through setting up a `Workflow` to perform basic RAG with reranking."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -U llama-index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ[\"OPENAI_API_KEY\"] = \"sk-proj-...\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### [Optional] Set up observability with Llamatrace\n",
    "\n",
    "Set up tracing to visualize each step in the workflow."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install \"openinference-instrumentation-llama-index>=3.0.0\" \"opentelemetry-proto>=1.12.0\" opentelemetry-exporter-otlp opentelemetry-sdk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from opentelemetry.sdk import trace as trace_sdk\n",
    "from opentelemetry.sdk.trace.export import SimpleSpanProcessor\n",
    "from opentelemetry.exporter.otlp.proto.http.trace_exporter import (\n",
    "    OTLPSpanExporter as HTTPSpanExporter,\n",
    ")\n",
    "from openinference.instrumentation.llama_index import LlamaIndexInstrumentor\n",
    "\n",
    "\n",
    "# Add Phoenix API Key for tracing\n",
    "PHOENIX_API_KEY = \"<YOUR-PHOENIX-API-KEY>\"\n",
    "os.environ[\"OTEL_EXPORTER_OTLP_HEADERS\"] = f\"api_key={PHOENIX_API_KEY}\"\n",
    "\n",
    "# Add Phoenix\n",
    "span_phoenix_processor = SimpleSpanProcessor(\n",
    "    HTTPSpanExporter(endpoint=\"https://app.phoenix.arize.com/v1/traces\")\n",
    ")\n",
    "\n",
    "# Add them to the tracer\n",
    "tracer_provider = trace_sdk.TracerProvider()\n",
    "tracer_provider.add_span_processor(span_processor=span_phoenix_processor)\n",
    "\n",
    "# Instrument the application\n",
    "LlamaIndexInstrumentor().instrument(tracer_provider=tracer_provider)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p data\n",
    "!wget --user-agent \"Mozilla\" \"https://arxiv.org/pdf/2307.09288.pdf\" -O \"data/llama2.pdf\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since workflows are async first, this all runs fine in a notebook. If you were running in your own code, you would want to use `asyncio.run()` to start an async event loop if one isn't already running.\n",
    "\n",
    "```python\n",
    "async def main():\n",
    "    <async code>\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    import asyncio\n",
    "    asyncio.run(main())\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Designing the Workflow\n",
    "\n",
    "RAG + Reranking consists of some clearly defined steps\n",
    "1. Indexing data, creating an index\n",
    "2. Using that index + a query to retrieve relevant text chunks\n",
    "3. Rerank the text retrieved text chunks using the original query\n",
    "4. Synthesizing a final response\n",
    "\n",
    "With this in mind, we can create events and workflow steps to follow this process!\n",
    "\n",
    "### The Workflow Events\n",
    "\n",
    "To handle these steps, we need to define a few events:\n",
    "1. An event to pass retrieved nodes to the reranker\n",
    "2. An event to pass reranked nodes to the synthesizer\n",
    "\n",
    "The other steps will use the built-in `StartEvent` and `StopEvent` events."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.workflow import Event\n",
    "from llama_index.core.schema import NodeWithScore\n",
    "\n",
    "\n",
    "class RetrieverEvent(Event):\n",
    "    \"\"\"Result of running retrieval\"\"\"\n",
    "\n",
    "    nodes: list[NodeWithScore]\n",
    "\n",
    "\n",
    "class RerankEvent(Event):\n",
    "    \"\"\"Result of running reranking on retrieved nodes\"\"\"\n",
    "\n",
    "    nodes: list[NodeWithScore]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The Workflow Itself\n",
    "\n",
    "With our events defined, we can construct our workflow and steps. \n",
    "\n",
    "Note that the workflow automatically validates itself using type annotations, so the type annotations on our steps are very helpful!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core import SimpleDirectoryReader, VectorStoreIndex\n",
    "from llama_index.core.response_synthesizers import CompactAndRefine\n",
    "from llama_index.core.postprocessor.llm_rerank import LLMRerank\n",
    "from llama_index.core.workflow import (\n",
    "    Context,\n",
    "    Workflow,\n",
    "    StartEvent,\n",
    "    StopEvent,\n",
    "    step,\n",
    ")\n",
    "\n",
    "from llama_index.llms.openai import OpenAI\n",
    "from llama_index.embeddings.openai import OpenAIEmbedding\n",
    "\n",
    "\n",
    "class RAGWorkflow(Workflow):\n",
    "    @step\n",
    "    async def ingest(self, ctx: Context, ev: StartEvent) -> StopEvent | None:\n",
    "        \"\"\"Entry point to ingest a document, triggered by a StartEvent with `dirname`.\"\"\"\n",
    "        dirname = ev.get(\"dirname\")\n",
    "        if not dirname:\n",
    "            return None\n",
    "\n",
    "        documents = SimpleDirectoryReader(dirname).load_data()\n",
    "        index = VectorStoreIndex.from_documents(\n",
    "            documents=documents,\n",
    "            embed_model=OpenAIEmbedding(model_name=\"text-embedding-3-small\"),\n",
    "        )\n",
    "        return StopEvent(result=index)\n",
    "\n",
    "    @step\n",
    "    async def retrieve(\n",
    "        self, ctx: Context, ev: StartEvent\n",
    "    ) -> RetrieverEvent | None:\n",
    "        \"Entry point for RAG, triggered by a StartEvent with `query`.\"\n",
    "        query = ev.get(\"query\")\n",
    "        index = ev.get(\"index\")\n",
    "\n",
    "        if not query:\n",
    "            return None\n",
    "\n",
    "        print(f\"Query the database with: {query}\")\n",
    "\n",
    "        # store the query in the global context\n",
    "        await ctx.store.set(\"query\", query)\n",
    "\n",
    "        # get the index from the global context\n",
    "        if index is None:\n",
    "            print(\"Index is empty, load some documents before querying!\")\n",
    "            return None\n",
    "\n",
    "        retriever = index.as_retriever(similarity_top_k=2)\n",
    "        nodes = await retriever.aretrieve(query)\n",
    "        print(f\"Retrieved {len(nodes)} nodes.\")\n",
    "        return RetrieverEvent(nodes=nodes)\n",
    "\n",
    "    @step\n",
    "    async def rerank(self, ctx: Context, ev: RetrieverEvent) -> RerankEvent:\n",
    "        # Rerank the nodes\n",
    "        ranker = LLMRerank(\n",
    "            choice_batch_size=5, top_n=3, llm=OpenAI(model=\"gpt-4o-mini\")\n",
    "        )\n",
    "        print(await ctx.store.get(\"query\", default=None), flush=True)\n",
    "        new_nodes = ranker.postprocess_nodes(\n",
    "            ev.nodes, query_str=await ctx.store.get(\"query\", default=None)\n",
    "        )\n",
    "        print(f\"Reranked nodes to {len(new_nodes)}\")\n",
    "        return RerankEvent(nodes=new_nodes)\n",
    "\n",
    "    @step\n",
    "    async def synthesize(self, ctx: Context, ev: RerankEvent) -> StopEvent:\n",
    "        \"\"\"Return a streaming response using reranked nodes.\"\"\"\n",
    "        llm = OpenAI(model=\"gpt-4o-mini\")\n",
    "        summarizer = CompactAndRefine(llm=llm, streaming=True, verbose=True)\n",
    "        query = await ctx.store.get(\"query\", default=None)\n",
    "\n",
    "        response = await summarizer.asynthesize(query, nodes=ev.nodes)\n",
    "        return StopEvent(result=response)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And thats it! Let's explore the workflow we wrote a bit.\n",
    "\n",
    "- We have two entry points (the steps that accept `StartEvent`)\n",
    "- The steps themselves decide when they can run\n",
    "- The workflow context is used to store the user query\n",
    "- The nodes are passed around, and finally a streaming response is returned"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run the Workflow!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "w = RAGWorkflow()\n",
    "\n",
    "# Ingest the documents\n",
    "index = await w.run(dirname=\"data\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Query the database with: How was Llama2 trained?\n",
      "Retrieved 2 nodes.\n",
      "How was Llama2 trained?\n",
      "Reranked nodes to 2\n",
      "Llama 2 was trained through a multi-step process that began with pretraining using publicly available online sources. This was followed by the creation of an initial version of Llama 2-Chat through supervised fine-tuning. The model was then iteratively refined using Reinforcement Learning with Human Feedback (RLHF) methodologies, which included rejection sampling and Proximal Policy Optimization (PPO). \n",
      "\n",
      "During pretraining, the model utilized an optimized auto-regressive transformer architecture, incorporating robust data cleaning, updated data mixes, and training on a significantly larger dataset of 2 trillion tokens. The training process also involved increased context length and the use of grouped-query attention (GQA) to enhance inference scalability.\n",
      "\n",
      "The training employed the AdamW optimizer with specific hyperparameters, including a cosine learning rate schedule and gradient clipping. The models were pretrained on Meta’s Research SuperCluster and internal production clusters, utilizing NVIDIA A100 GPUs."
     ]
    }
   ],
   "source": [
    "# Run a query\n",
    "result = await w.run(query=\"How was Llama2 trained?\", index=index)\n",
    "async for chunk in result.async_response_gen():\n",
    "    print(chunk, end=\"\", flush=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama-index-cDlKpkFt-py3.11",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
